# Ansible playbook to set up the inference server

---
# Play: provisions the inference stack (Docker network, Redis, Postgres and
# the inference server container) on the hosts in the `dev` inventory group.
- name: Deploy inference
  hosts: dev
  gather_facts: true
  vars:
    # Interpolated into the names of the network, containers, volume and
    # stack directory managed by the tasks of this play.
    stack_name: "dev"
    # Tag of the inference-server image to run; quoted so it is always a
    # string, whatever value is substituted here.
    image_tag: "latest"
    # Host port published for the inference server's container port 8080.
    server_port: 8080
  tasks:
    # Ensure the per-stack bridge network exists; the containers in this play
    # reference it by name through `network_mode`.
    - name: Create network
      community.docker.docker_network:
        name: "oasst-{{ stack_name }}"
        driver: bridge
        state: present

    # Directory on the managed node that holds the files bind-mounted into the
    # stack's containers (currently redis.conf).
    # NOTE(review): the path is relative, so it presumably resolves against
    # the remote user's working directory — confirm this is intended.
    - name: Create stack files directory
      ansible.builtin.file:
        path: "./{{ stack_name }}-inference"
        state: directory
        # Quoted so the mode reaches the module as the string "0755" rather
        # than depending on the YAML parser's octal-integer handling.
        mode: "0755"

    # Upload the Redis configuration from the control node into the stack
    # directory, where the Redis container bind-mounts it.
    - name: Copy redis.conf to managed node
      ansible.builtin.copy:
        src: ./redis.conf
        dest: "./{{ stack_name }}-inference/redis.conf"
        # Quoted so the mode reaches the module as the string "0644" rather
        # than depending on the YAML parser's octal-integer handling.
        mode: "0644"

    # Run Redis for the inference stack on the stack network, started with the
    # redis.conf that is bind-mounted from the stack directory.
    - name: Set up inference Redis
      community.docker.docker_container:
        name: "oasst-inference-{{ stack_name }}-redis"
        image: redis
        command: redis-server /usr/local/etc/redis/redis.conf
        network_mode: "oasst-{{ stack_name }}"
        volumes:
          - >-
            ./{{ stack_name }}-inference/redis.conf:/usr/local/etc/redis/redis.conf
        state: started
        # Evaluates to true only when stack_name is 'dev', so only the dev
        # stack's container is force-recreated on each run.
        recreate: "{{ (stack_name == 'dev') | bool }}"
        restart_policy: always
        # Healthy once `redis-cli ping` answers PONG.
        healthcheck:
          test:
            - "CMD-SHELL"
            - "redis-cli ping | grep PONG"
          interval: 2s
          timeout: 2s
          retries: 10

    # Named Docker volume mounted at the Postgres container's data directory.
    - name: Create volumes for inference postgres
      community.docker.docker_volume:
        state: present
        name: "oasst-inference-{{ stack_name }}-postgres"

    # Run the project's Postgres image on the stack network, storing its data
    # in the named volume. Credentials and S3 settings are read from the
    # environment of the control node at playbook run time.
    - name: Create postgres containers
      community.docker.docker_container:
        name: "oasst-inference-{{ stack_name }}-postgres"
        image: ghcr.io/laion-ai/open-assistant/oasst-postgres
        pull: true
        state: started
        # Evaluates to true only when stack_name is 'dev'.
        recreate: "{{ (stack_name == 'dev') | bool }}"
        restart_policy: always
        network_mode: "oasst-{{ stack_name }}"
        shm_size: 1G
        volumes:
          - "oasst-inference-{{ stack_name }}-postgres:/var/lib/postgresql/data"
        env:
          POSTGRES_USER: postgres
          # Falls back to 'postgres' when INFERENCE_POSTGRES_PASSWORD is unset
          # or empty.
          POSTGRES_PASSWORD: >-
            {{ lookup('ansible.builtin.env', 'INFERENCE_POSTGRES_PASSWORD')
            | default('postgres', true) }}
          POSTGRES_DB: postgres
          # NOTE(review): the S3/AWS values are passed through with no
          # fallback — presumably consumed by the image itself; verify.
          S3_BUCKET_NAME: "{{ lookup('ansible.builtin.env', 'S3_BUCKET_NAME') }}"
          S3_PREFIX: "inference"
          AWS_ACCESS_KEY_ID: "{{ lookup('ansible.builtin.env', 'AWS_ACCESS_KEY') }}"
          AWS_SECRET_ACCESS_KEY: >-
            {{ lookup('ansible.builtin.env', 'AWS_SECRET_KEY') }}
          AWS_DEFAULT_REGION: "{{ lookup('ansible.builtin.env', 'S3_REGION') }}"
        # Healthy once pg_isready succeeds for the postgres user.
        healthcheck:
          test:
            - "CMD"
            - "pg_isready"
            - "-U"
            - "postgres"
          interval: 2s
          timeout: 2s
          retries: 10

    # Start the inference server container on the stack network, pointing it
    # at the Postgres and Redis containers by container name and publishing
    # container port 8080 on `server_port`.
    #
    # Every INFERENCE_* setting is read from the control node's environment at
    # run time; `default(..., true)` supplies the fallback when the variable is
    # unset or empty. All fallbacks are written as strings because
    # docker_container rejects non-string env values, and an integer default
    # would reach the module as an int when templating preserves native types.
    - name: Run the oasst inference-server
      community.docker.docker_container:
        name: "oasst-inference-{{ stack_name }}-server"
        image:
          "ghcr.io/laion-ai/open-assistant/oasst-inference-server:{{ image_tag
          }}"
        state: started
        recreate: true
        pull: true
        restart_policy: always
        network_mode: "oasst-{{ stack_name }}"
        env:
          POSTGRES_HOST: "oasst-inference-{{ stack_name }}-postgres"
          POSTGRES_PASSWORD:
            "{{ lookup('ansible.builtin.env', 'INFERENCE_POSTGRES_PASSWORD') |
            default('postgres', true) }}"
          REDIS_HOST: "oasst-inference-{{ stack_name }}-redis"
          LOGURU_LEVEL:
            "{{ lookup('ansible.builtin.env', 'INFERENCE_LOG_LEVEL') |
            default('INFO', true) }}"
          DEBUG_API_KEYS:
            "{{ lookup('ansible.builtin.env', 'INFERENCE_DEBUG_API_KEYS') |
            default('', true) | string }}"
          # `| string` keeps the literal text 'False' from being converted to
          # a boolean by the templating layer.
          ALLOW_DEBUG_AUTH:
            "{{ lookup('ansible.builtin.env', 'INFERENCE_ALLOW_DEBUG_AUTH') |
            default('False', true) | string }}"
          # NOTE(review): falls back to the placeholder '1234' when
          # INFERENCE_ROOT_TOKEN is unset — confirm this is acceptable for
          # every environment this playbook targets.
          ROOT_TOKEN:
            "{{ lookup('ansible.builtin.env', 'INFERENCE_ROOT_TOKEN') |
            default('1234', true) }}"
          API_ROOT:
            "{{ lookup('ansible.builtin.env', 'INFERENCE_API_ROOT') |
            default('https://inference.dev.open-assistant.io', true) }}"
          TRUSTED_CLIENT_KEYS:
            "{{ lookup('ansible.builtin.env', 'INFERENCE_TRUSTED_CLIENT_KEYS') |
            default('', true) }}"
          AUTH_SALT:
            "{{ lookup('ansible.builtin.env', 'INFERENCE_AUTH_SALT') |
            default('', true) }}"
          AUTH_SECRET:
            "{{ lookup('ansible.builtin.env', 'INFERENCE_AUTH_SECRET') |
            default('', true) }}"
          AUTH_DISCORD_CLIENT_ID:
            "{{ lookup('ansible.builtin.env',
            'INFERENCE_AUTH_DISCORD_CLIENT_ID') | default('', true) }}"
          AUTH_DISCORD_CLIENT_SECRET:
            "{{ lookup('ansible.builtin.env',
            'INFERENCE_AUTH_DISCORD_CLIENT_SECRET') | default('', true) }}"
          AUTH_GITHUB_CLIENT_ID:
            "{{ lookup('ansible.builtin.env', 'INFERENCE_AUTH_GITHUB_CLIENT_ID')
            | default('', true) }}"
          AUTH_GITHUB_CLIENT_SECRET:
            "{{ lookup('ansible.builtin.env',
            'INFERENCE_AUTH_GITHUB_CLIENT_SECRET') | default('', true) }}"
          INFERENCE_CORS_ORIGINS:
            "{{ lookup('ansible.builtin.env', 'INFERENCE_CORS_ORIGINS') |
            default('*', true) }}"
          ALLOWED_MODEL_CONFIG_NAMES:
            "{{ lookup('ansible.builtin.env',
            'INFERENCE_ALLOWED_MODEL_CONFIG_NAMES') | default('*', true) }}"
          ASSISTANT_MESSAGE_TIMEOUT:
            "{{ lookup('ansible.builtin.env',
            'INFERENCE_ASSISTANT_MESSAGE_TIMEOUT') | default('120', true) }}"
          MESSAGE_QUEUE_EXPIRE:
            "{{ lookup('ansible.builtin.env', 'INFERENCE_MESSAGE_QUEUE_EXPIRE')
            | default('120', true) }}"
          WORK_QUEUE_MAX_SIZE:
            "{{ lookup('ansible.builtin.env', 'INFERENCE_WORK_QUEUE_MAX_SIZE') |
            default('100', true) }}"
          ENABLE_SAFETY:
            "{{ lookup('ansible.builtin.env', 'INFERENCE_ENABLE_SAFETY') |
            default('False', true) | string }}"
          GUNICORN_WORKERS:
            "{{ lookup('ansible.builtin.env', 'INFERENCE_GUNICORN_WORKERS') |
            default('1', true) }}"
          CHAT_MAX_MESSAGES:
            "{{ lookup('ansible.builtin.env', 'INFERENCE_CHAT_MAX_MESSAGES') |
            default('', true) }}"
          MESSAGE_MAX_LENGTH:
            "{{ lookup('ansible.builtin.env', 'INFERENCE_MESSAGE_MAX_LENGTH') |
            default('', true) }}"
          PLUGIN_MAX_DEPTH:
            "{{ lookup('ansible.builtin.env', 'INFERENCE_PLUGIN_MAX_DEPTH') |
            default('4', true) }}"
        ports:
          - "{{ server_port }}:8080"
